[data][feat] add large dataset support #1298
base: main
Conversation
Force-pushed from 24fb033 to 7d855c6.
@codex review

To use Codex here, create a Codex account and connect it to GitHub.

@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 8ee19cc3fa
```python
# Store original content length before padding for correct unpacking
packed_sequence["content_length"] = seq_length
```
Preserve original `content_length` across CP padding
In CP mode, `_get_model_inputs_args` pads the same packed batch multiple times (e.g., once during `_compute_log_prob` and again during `_train_step`). `pad_packed_sequence_with_cp` overwrites `content_length` unconditionally, so on the second call it records the already-padded length. `unpack_sequences` then computes `pad_length` as 0 and slices `log_probs`/`entropy` for the last sequence using the padded end index, which drops real tokens and includes padding when `cp_size > 1`, corrupting the loss and metrics for that sequence. Consider setting `content_length` only when it is absent, or skipping the padding when the batch is already padded.
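A minimal sketch of the first option, reusing the function and key names from the comment above; the padding arithmetic is an illustrative assumption, not the actual slime implementation:

```python
import torch

def pad_packed_sequence_with_cp(packed_sequence: dict, cp_size: int) -> dict:
    # Sketch: pad the packed length to a multiple of 2 * cp_size.
    seq_length = packed_sequence["tokens"].shape[-1]
    # Record the original length only on the first padding pass; a repeated
    # call must not overwrite it with the already-padded length.
    packed_sequence.setdefault("content_length", seq_length)
    pad_length = (-seq_length) % (2 * cp_size)
    if pad_length:
        packed_sequence["tokens"] = torch.nn.functional.pad(
            packed_sequence["tokens"], (0, pad_length), value=0
        )
    return packed_sequence
```

With `setdefault`, the second padding pass inside `_train_step` still sees the length recorded during `_compute_log_prob`, so `unpack_sequences` can compute the true `pad_length`.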
PopSoda2002 left a comment:
Thanks for your great work! I am not sure whether we need to separate the dataset and SFT changes into different PRs; currently this PR seems too large to review.
slime/backends/fsdp_utils/actor.py (outdated)
```python
if self.args.calculate_per_token_loss:
    sft_loss = sum_of_token(sft_loss, response_lengths, loss_masks)
else:
    sft_loss = sum_of_sample_mean(sft_loss, response_lengths, loss_masks)
```
I think we just need to calculate per-sample loss in SFT?
After further discussion, we realized that SFT should only use per-token loss.
We’ll simplify this logic by keeping only the token loss here.
For users who still try to use sequence / per-sample loss in SFT, we’ll explicitly raise an error to avoid silent misconfiguration.
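A minimal sketch of that validation, assuming hypothetical flag names (`calculate_per_token_loss` appears in the diff above; `loss_type` is an assumption):

```python
from argparse import Namespace

def validate_sft_loss_config(args: Namespace) -> None:
    # Hypothetical guard: SFT must use per-token loss; fail fast instead of
    # silently training with per-sample/sequence normalization.
    if args.loss_type == "sft_loss" and not args.calculate_per_token_loss:
        raise ValueError(
            "SFT supports per-token loss only; "
            "enable calculate_per_token_loss or drop the SFT loss type."
        )

# Example: a misconfigured run fails at startup rather than mid-training.
validate_sft_loss_config(Namespace(loss_type="sft_loss", calculate_per_token_loss=True))
```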
slime/rollout/data_source.py (outdated)
```python
seed=self.args.rollout_seed,
apply_chat_template=self.args.apply_chat_template,
apply_chat_template_kwargs=self.args.apply_chat_template_kwargs,
dp_size=self._dp_size or 1,
```
We should not pass the DP size to the data source, because the data source is used in the rollout manager, which does not have DP ranks.
> We should not pass the DP size to the data source, because the data source is used in the rollout manager, which does not have DP ranks.
Apologies for the mistake, I’ll address it in the upcoming commits.
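For illustration only, a minimal sketch of the DP-agnostic direction implied here, with hypothetical names (not slime’s actual API): the rollout-side data source carries no data-parallel state, and sharding happens on the training side.

```python
class DataSource:
    # Sketch: the rollout manager owns this object, so it holds no DP state.
    def __init__(self, seed: int, apply_chat_template: bool = False):
        self.seed = seed
        self.apply_chat_template = apply_chat_template

def shard_for_rank(samples: list, dp_rank: int, dp_size: int) -> list:
    # Training side: each data-parallel rank takes a strided slice of the
    # global batch instead of the data source pre-sharding it.
    return samples[dp_rank::dp_size]
```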
I’ll work on some simplifications with @ChangyiYang today. Afterward, could you review it again and discuss whether we should split this PR into two?
Yeah, sure, definitely. Always willing to help!
Thank you for the efforts on radixark/miles#246 and the hard work contributed by @Ratish1!